{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training a Monty Hall Bayesian Network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "authors:<br>\n",
    "Jacob Schreiber [<a href=\"mailto:jmschreiber91@gmail.com\">jmschreiber91@gmail.com</a>]<br>\n",
    "Nicholas Farn [<a href=\"mailto:nicholasfarn@gmail.com\">nicholasfarn@gmail.com</a>]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lets test out the Bayesian Network framework to produce the Monty Hall problem, but modified a little. The Monty Hall problem is basically a game show where a guest chooses one of three doors to open, with an unknown one having a prize behind it. Monty then opens another non-chosen door without a prize behind it, and asks the guest if they would like to change their answer. Many people were surprised to find that if the guest changed their answer, there was a 66% chance of success as opposed to a 50% as might be expected if there were two doors.\n",
    "\n",
    "This can be modelled as a Bayesian network with three nodes-- guest, prize, and Monty, each over the domain of door 'A', 'B', 'C'. Monty is dependent on both guest and prize, in that it can't be either of them. Lets extend this a little bit to say the guest has an untrustworthy friend whose answer he will not go with."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from pomegranate import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create the distributions for the guest and the prize. Note that both distributions are independent of one another."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "guest = DiscreteDistribution( { 'A': 1./3, 'B': 1./3, 'C': 1./3 } )\n",
    "prize = DiscreteDistribution( { 'A': 1./3, 'B': 1./3, 'C': 1./3 } )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's create the conditional probability table for our Monty. The table is dependent on both the guest and the prize."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "monty = ConditionalProbabilityTable(\n",
    "\t[[ 'A', 'A', 'A', 0.0 ],\n",
    "\t [ 'A', 'A', 'B', 0.5 ],\n",
    "\t [ 'A', 'A', 'C', 0.5 ],\n",
    "\t [ 'A', 'B', 'A', 0.0 ],\n",
    "\t [ 'A', 'B', 'B', 0.0 ],\n",
    "\t [ 'A', 'B', 'C', 1.0 ],\n",
    "\t [ 'A', 'C', 'A', 0.0 ],\n",
    "\t [ 'A', 'C', 'B', 1.0 ],\n",
    "\t [ 'A', 'C', 'C', 0.0 ],\n",
    "\t [ 'B', 'A', 'A', 0.0 ],\n",
    "\t [ 'B', 'A', 'B', 0.0 ],\n",
    "\t [ 'B', 'A', 'C', 1.0 ],\n",
    "\t [ 'B', 'B', 'A', 0.5 ],\n",
    "\t [ 'B', 'B', 'B', 0.0 ],\n",
    "\t [ 'B', 'B', 'C', 0.5 ],\n",
    "\t [ 'B', 'C', 'A', 1.0 ],\n",
    "\t [ 'B', 'C', 'B', 0.0 ],\n",
    "\t [ 'B', 'C', 'C', 0.0 ],\n",
    "\t [ 'C', 'A', 'A', 0.0 ],\n",
    "\t [ 'C', 'A', 'B', 1.0 ],\n",
    "\t [ 'C', 'A', 'C', 0.0 ],\n",
    "\t [ 'C', 'B', 'A', 1.0 ],\n",
    "\t [ 'C', 'B', 'B', 0.0 ],\n",
    "\t [ 'C', 'B', 'C', 0.0 ],\n",
    "\t [ 'C', 'C', 'A', 0.5 ],\n",
    "\t [ 'C', 'C', 'B', 0.5 ],\n",
    "\t [ 'C', 'C', 'C', 0.0 ]], [guest, prize] )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now lets create the states for the bayesian network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "s1 = State( guest, name=\"guest\" )\n",
    "s2 = State( prize, name=\"prize\" )\n",
    "s3 = State( monty, name=\"monty\" )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then the bayesian network itself, adding the states in after."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "network = BayesianNetwork( \"test\" )\n",
    "network.add_states( s1, s2, s3 )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then the transitions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "network.add_transition( s1, s3 )\n",
    "network.add_transition( s2, s3 )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With a \"bake\" to finalize the structure of our network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "network.bake()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can train our network on the following data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{\n",
       "    \"class\" : \"BayesianNetwork\",\n",
       "    \"name\" : \"test\",\n",
       "    \"structure\" : [\n",
       "        [],\n",
       "        [],\n",
       "        [\n",
       "            0,\n",
       "            1\n",
       "        ]\n",
       "    ],\n",
       "    \"states\" : [\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : {\n",
       "                \"class\" : \"Distribution\",\n",
       "                \"dtype\" : \"str\",\n",
       "                \"name\" : \"DiscreteDistribution\",\n",
       "                \"parameters\" : [\n",
       "                    {\n",
       "                        \"A\" : 0.4166666666666667,\n",
       "                        \"B\" : 0.16666666666666666,\n",
       "                        \"C\" : 0.4166666666666667\n",
       "                    }\n",
       "                ],\n",
       "                \"frozen\" : false\n",
       "            },\n",
       "            \"name\" : \"guest\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : {\n",
       "                \"class\" : \"Distribution\",\n",
       "                \"dtype\" : \"str\",\n",
       "                \"name\" : \"DiscreteDistribution\",\n",
       "                \"parameters\" : [\n",
       "                    {\n",
       "                        \"A\" : 0.4166666666666667,\n",
       "                        \"B\" : 0.25,\n",
       "                        \"C\" : 0.3333333333333333\n",
       "                    }\n",
       "                ],\n",
       "                \"frozen\" : false\n",
       "            },\n",
       "            \"name\" : \"prize\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : {\n",
       "                \"class\" : \"Distribution\",\n",
       "                \"name\" : \"ConditionalProbabilityTable\",\n",
       "                \"table\" : [\n",
       "                    [\n",
       "                        \"A\",\n",
       "                        \"A\",\n",
       "                        \"A\",\n",
       "                        \"1.0\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"A\",\n",
       "                        \"A\",\n",
       "                        \"B\",\n",
       "                        \"0.0\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"A\",\n",
       "                        \"A\",\n",
       "                        \"C\",\n",
       "                        \"0.0\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"A\",\n",
       "                        \"B\",\n",
       "                        \"A\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"A\",\n",
       "                        \"B\",\n",
       "                        \"B\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"A\",\n",
       "                        \"B\",\n",
       "                        \"C\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"A\",\n",
       "                        \"C\",\n",
       "                        \"A\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"A\",\n",
       "                        \"C\",\n",
       "                        \"B\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"A\",\n",
       "                        \"C\",\n",
       "                        \"C\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"B\",\n",
       "                        \"A\",\n",
       "                        \"A\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"B\",\n",
       "                        \"A\",\n",
       "                        \"B\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"B\",\n",
       "                        \"A\",\n",
       "                        \"C\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"B\",\n",
       "                        \"B\",\n",
       "                        \"A\",\n",
       "                        \"0.0\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"B\",\n",
       "                        \"B\",\n",
       "                        \"B\",\n",
       "                        \"0.5\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"B\",\n",
       "                        \"B\",\n",
       "                        \"C\",\n",
       "                        \"0.5\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"B\",\n",
       "                        \"C\",\n",
       "                        \"A\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"B\",\n",
       "                        \"C\",\n",
       "                        \"B\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"B\",\n",
       "                        \"C\",\n",
       "                        \"C\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"C\",\n",
       "                        \"A\",\n",
       "                        \"A\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"C\",\n",
       "                        \"A\",\n",
       "                        \"B\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"C\",\n",
       "                        \"A\",\n",
       "                        \"C\",\n",
       "                        \"0.3333333333333333\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"C\",\n",
       "                        \"B\",\n",
       "                        \"A\",\n",
       "                        \"1.0\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"C\",\n",
       "                        \"B\",\n",
       "                        \"B\",\n",
       "                        \"0.0\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"C\",\n",
       "                        \"B\",\n",
       "                        \"C\",\n",
       "                        \"0.0\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"C\",\n",
       "                        \"C\",\n",
       "                        \"A\",\n",
       "                        \"0.25\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"C\",\n",
       "                        \"C\",\n",
       "                        \"B\",\n",
       "                        \"0.0\"\n",
       "                    ],\n",
       "                    [\n",
       "                        \"C\",\n",
       "                        \"C\",\n",
       "                        \"C\",\n",
       "                        \"0.75\"\n",
       "                    ]\n",
       "                ],\n",
       "                \"dtypes\" : [\n",
       "                    \"str\",\n",
       "                    \"str\",\n",
       "                    \"str\",\n",
       "                    \"float\"\n",
       "                ],\n",
       "                \"parents\" : [\n",
       "                    {\n",
       "                        \"class\" : \"Distribution\",\n",
       "                        \"dtype\" : \"str\",\n",
       "                        \"name\" : \"DiscreteDistribution\",\n",
       "                        \"parameters\" : [\n",
       "                            {\n",
       "                                \"A\" : 0.4166666666666667,\n",
       "                                \"B\" : 0.16666666666666666,\n",
       "                                \"C\" : 0.4166666666666667\n",
       "                            }\n",
       "                        ],\n",
       "                        \"frozen\" : false\n",
       "                    },\n",
       "                    {\n",
       "                        \"class\" : \"Distribution\",\n",
       "                        \"dtype\" : \"str\",\n",
       "                        \"name\" : \"DiscreteDistribution\",\n",
       "                        \"parameters\" : [\n",
       "                            {\n",
       "                                \"A\" : 0.4166666666666667,\n",
       "                                \"B\" : 0.25,\n",
       "                                \"C\" : 0.3333333333333333\n",
       "                            }\n",
       "                        ],\n",
       "                        \"frozen\" : false\n",
       "                    }\n",
       "                ]\n",
       "            },\n",
       "            \"name\" : \"monty\",\n",
       "            \"weight\" : 1.0\n",
       "        }\n",
       "    ]\n",
       "}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = [[ 'A', 'A', 'A' ],\n",
    "\t\t[ 'A', 'A', 'A' ],\n",
    "\t\t[ 'A', 'A', 'A' ],\n",
    "\t\t[ 'A', 'A', 'A' ],\n",
    "\t\t[ 'A', 'A', 'A' ],\n",
    "\t\t[ 'B', 'B', 'B' ],\n",
    "\t\t[ 'B', 'B', 'C' ],\n",
    "\t\t[ 'C', 'C', 'A' ],\n",
    "\t\t[ 'C', 'C', 'C' ],\n",
    "\t\t[ 'C', 'C', 'C' ],\n",
    "\t\t[ 'C', 'C', 'C' ],\n",
    "\t\t[ 'C', 'B', 'A' ]]\n",
    "\n",
    "network.fit( data )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's see what happens when our Guest says 'A' and the Prize is 'A'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "guest\tA\n",
      "prize\tA\n",
      "monty\t{\n",
      "    \"class\" :\"Distribution\",\n",
      "    \"dtype\" :\"str\",\n",
      "    \"name\" :\"DiscreteDistribution\",\n",
      "    \"parameters\" :[\n",
      "        {\n",
      "            \"A\" :1.0,\n",
      "            \"B\" :0.0,\n",
      "            \"C\" :0.0\n",
      "        }\n",
      "    ],\n",
      "    \"frozen\" :false\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "observations = { 'guest' : 'A', 'prize' : 'A' }\n",
    "beliefs = map( str, network.predict_proba( observations ) )\n",
    "print(\"\\n\".join( \"{}\\t{}\".format( state.name, belief ) for state, belief in zip( network.states, beliefs ) ))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
